{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Recommendation engine using collaborative filtering on MovieLens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "from fastai.learner import *\n",
    "from fastai.column_data import *\n",
    "from fastai.imports import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '.'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "collaborating filter.ipynb  ml-latest-small.zip  movielens.ipynb    tmp\r\n",
      "ml-latest-small\t\t    models\t\t ratings_small.csv\r\n"
     ]
    }
   ],
   "source": [
    "! ls ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ratings = pd.read_csv('ratings_small.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>31</td>\n",
       "      <td>2.5</td>\n",
       "      <td>1260759144</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1029</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1260759179</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>1061</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1260759182</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1129</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1260759185</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>1172</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1260759205</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   userId  movieId  rating   timestamp\n",
       "0       1       31     2.5  1260759144\n",
       "1       1     1029     3.0  1260759179\n",
       "2       1     1061     3.0  1260759182\n",
       "3       1     1129     2.0  1260759185\n",
       "4       1     1172     4.0  1260759205"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ratings.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100004, 4)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ratings.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are no missing values (NAs) in the ratings data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_users=int(ratings.userId.nunique())\n",
    "n_movies=int(ratings.movieId.nunique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n_users =  671 || n_movies =  9066\n"
     ]
    }
   ],
   "source": [
    "print(\"n_users = \",n_users, \"||\", \"n_movies = \", n_movies )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
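    "Let's create a cross-tab of the 15 most active users and the 15 most-rated movies to visualize how user IDs, movie IDs and ratings relate. A `NaN` in the table simply means that user has not rated that movie."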
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = ratings.groupby('userId')['rating'].count()\n",
    "topg = g.sort_values(ascending = False)[:15]\n",
    "\n",
    "i = ratings.groupby('movieId')['rating'].count()\n",
    "topi = i.sort_values(ascending = False)[:15]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>movieId</th>\n",
       "      <th>1</th>\n",
       "      <th>110</th>\n",
       "      <th>260</th>\n",
       "      <th>296</th>\n",
       "      <th>318</th>\n",
       "      <th>356</th>\n",
       "      <th>480</th>\n",
       "      <th>527</th>\n",
       "      <th>589</th>\n",
       "      <th>593</th>\n",
       "      <th>608</th>\n",
       "      <th>1196</th>\n",
       "      <th>1198</th>\n",
       "      <th>1270</th>\n",
       "      <th>2571</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>userId</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>2.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>73</th>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>212</th>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>3.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>213</th>\n",
       "      <td>3.0</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>4.0</td>\n",
       "      <td>2.5</td>\n",
       "      <td>2.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>294</th>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>311</th>\n",
       "      <td>3.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>2.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>4.5</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>380</th>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>452</th>\n",
       "      <td>3.5</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>468</th>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>3.5</td>\n",
       "      <td>3.5</td>\n",
       "      <td>3.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>2.5</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>3.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>509</th>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>547</th>\n",
       "      <td>3.5</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.5</td>\n",
       "      <td>2.0</td>\n",
       "      <td>3.5</td>\n",
       "      <td>3.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>564</th>\n",
       "      <td>4.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>580</th>\n",
       "      <td>4.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.5</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>624</th>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>3.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "movieId  1     110   260   296   318   356   480   527   589   593   608   \\\n",
       "userId                                                                      \n",
       "15        2.0   3.0   5.0   5.0   2.0   1.0   3.0   4.0   4.0   5.0   5.0   \n",
       "30        4.0   5.0   4.0   5.0   5.0   5.0   4.0   5.0   4.0   4.0   5.0   \n",
       "73        5.0   4.0   4.5   5.0   5.0   5.0   4.0   5.0   3.0   4.5   4.0   \n",
       "212       3.0   5.0   4.0   4.0   4.5   4.0   3.0   5.0   3.0   4.0   NaN   \n",
       "213       3.0   2.5   5.0   NaN   NaN   2.0   5.0   NaN   4.0   2.5   2.0   \n",
       "294       4.0   3.0   4.0   NaN   3.0   4.0   4.0   4.0   3.0   NaN   NaN   \n",
       "311       3.0   3.0   4.0   3.0   4.5   5.0   4.5   5.0   4.5   2.0   4.0   \n",
       "380       4.0   5.0   4.0   5.0   4.0   5.0   4.0   NaN   4.0   5.0   4.0   \n",
       "452       3.5   4.0   4.0   5.0   5.0   4.0   5.0   4.0   4.0   5.0   5.0   \n",
       "468       4.0   3.0   3.5   3.5   3.5   3.0   2.5   NaN   NaN   3.0   4.0   \n",
       "509       3.0   5.0   5.0   5.0   4.0   4.0   3.0   5.0   2.0   4.0   4.5   \n",
       "547       3.5   NaN   NaN   5.0   5.0   2.0   3.0   5.0   NaN   5.0   5.0   \n",
       "564       4.0   1.0   2.0   5.0   NaN   3.0   5.0   4.0   5.0   5.0   5.0   \n",
       "580       4.0   4.5   4.0   4.5   4.0   3.5   3.0   4.0   4.5   4.0   4.5   \n",
       "624       5.0   NaN   5.0   5.0   NaN   3.0   3.0   NaN   3.0   5.0   4.0   \n",
       "\n",
       "movieId  1196  1198  1270  2571  \n",
       "userId                           \n",
       "15        5.0   4.0   5.0   5.0  \n",
       "30        4.0   5.0   5.0   3.0  \n",
       "73        5.0   5.0   5.0   4.5  \n",
       "212       NaN   3.0   3.0   5.0  \n",
       "213       5.0   3.0   3.0   4.0  \n",
       "294       4.0   4.5   4.0   4.5  \n",
       "311       3.0   4.5   4.5   4.0  \n",
       "380       4.0   NaN   3.0   5.0  \n",
       "452       4.0   4.0   4.0   2.0  \n",
       "468       3.0   3.5   3.0   3.0  \n",
       "509       5.0   5.0   3.0   4.5  \n",
       "547       2.5   2.0   3.5   3.5  \n",
       "564       5.0   5.0   3.0   3.0  \n",
       "580       4.0   3.5   3.0   4.5  \n",
       "624       5.0   5.0   5.0   2.0  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# gettings ratings of top users and top items\n",
    "\n",
    "join1 = ratings.join(topg, on='userId', how = 'inner', rsuffix='_r')\n",
    "join1 = join1.join(topi, on='movieId', how = 'inner', rsuffix = '_r')\n",
    "\n",
    "pd.crosstab(join1.userId, join1.movieId, join1.rating, aggfunc=np.sum)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Collaborative filtering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_indx = get_cv_idxs(len(ratings))  # index for validation set\n",
    "wd = 2e-4 # weight decay\n",
    "n_factors = 50 # n_factors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data loader\n",
    "cf = CollabFilterDataset.from_csv(path, 'ratings_small.csv', 'userId', 'movieId', 'rating')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cf.get_learner(n_factors, val_indx, bs=64, opt_fn=optim.Adam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "065d1fd5a51945359526689db634fd06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in Jupyter Notebook or JupyterLab, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another notebook frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       0.7727   0.80396]                                    \n",
      "[ 1.       0.77782  0.77585]                                    \n",
      "[ 2.       0.58389  0.76542]                                    \n",
      "\n"
     ]
    }
   ],
   "source": [
    "learn.fit(1e-2,2, wds = wd, cycle_len=1, cycle_mult=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After 3 epochs the validation loss (MSE) is about 0.765."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Collaborative filtering from scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "u_uniq = ratings.userId.unique()\n",
    "user2idx = {o:i for i,o in enumerate(u_uniq)}\n",
    "ratings.userId = ratings.userId.apply(lambda x: user2idx[x])\n",
    "\n",
    "m_uniq = ratings.movieId.unique()\n",
    "movie2idx = {o:i for i,o in enumerate(m_uniq)}\n",
    "ratings.movieId = ratings.movieId.apply(lambda x: movie2idx[x])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(671, 9066)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_users, n_movies"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`nn.Embedding` creates a lookup table that stores one embedding vector per index of a fixed-size dictionary, so an embedding can be retrieved by its integer index (here: a user index or a movie index). Each `nn.Embedding` also gives us its learnable weight matrix for free as `.weight` (e.g. `self.u.weight`), whose rows are the corresponding embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_indx = get_cv_idxs(len(ratings))  # index for validation set\n",
    "wd = 2e-4 # weight decay\n",
    "n_factors = 50 # n_factors i.e. 1 dimension of embeddings (random)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.5, 5.0)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "min_rating,max_rating = ratings.rating.min(),ratings.rating.max()\n",
    "min_rating,max_rating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_emb(ni, nf, scale=0.01):\n",
    "    \"\"\"Create an nn.Embedding with `ni` rows of `nf` factors each.\n",
    "\n",
    "    Weights are initialised uniformly in [-scale, scale] (default 0.01).\n",
    "    Returns the new embedding module.\n",
    "    \"\"\"\n",
    "    e = nn.Embedding(ni, nf)\n",
    "    e.weight.data.uniform_(-scale, scale)\n",
    "    return e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = ratings.drop(['rating'],axis=1)\n",
    "y = ratings['rating'].astype(np.float32)\n",
    "\n",
    "data = ColumnarModelData.from_data_frame(path, val_indx, x, y, ['userId', 'movieId'], 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EmbeddingNet hyper-parameters:\n",
    "#   nh        = number of units in the hidden linear layer\n",
    "#   p1        = dropout probability applied to the concatenated embeddings\n",
    "#   p2        = dropout probability applied after the hidden layer\n",
    "#   n_factors = width of each user / movie embedding (default 50 matches the config cell)\n",
    "\n",
    "class EmbeddingNet(nn.Module):\n",
    "    \"\"\"Neural collaborative-filtering model.\n",
    "\n",
    "    Looks up a user embedding and a movie embedding, concatenates them and\n",
    "    feeds the result through a small MLP (Linear -> ReLU -> Linear). The\n",
    "    final sigmoid output is rescaled to the rating range given by the\n",
    "    notebook globals min_rating and max_rating.\n",
    "    \"\"\"\n",
    "    def __init__(self, n_users, _n_movies, nh = 10, p1 = 0.05, p2= 0.5, n_factors = 50):\n",
    "        super().__init__()\n",
    "        # Size the lookup tables from the constructor arguments (previously the\n",
    "        # movie count and factor width silently came from notebook globals).\n",
    "        (self.u, self.m, self.ub, self.mb) = [get_emb(*o) for o in [\n",
    "            (n_users, n_factors), (_n_movies, n_factors),\n",
    "            (n_users, 1), (_n_movies, 1)\n",
    "        ]]\n",
    "        # NOTE(review): the bias tables ub/mb are created but never used in forward().\n",
    "\n",
    "        self.lin1 = nn.Linear(n_factors*2, nh)  # bias is True by default\n",
    "        self.lin2 = nn.Linear(nh, 1)\n",
    "        self.drop1 = nn.Dropout(p = p1)\n",
    "        self.drop2 = nn.Dropout(p = p2)\n",
    "\n",
    "    def forward(self, cats, conts):\n",
    "        \"\"\"Predict ratings for a batch of (user index, movie index) pairs.\n",
    "\n",
    "        cats  -- index tensor; column 0 holds user indices, column 1 movie indices.\n",
    "        conts -- continuous features; not used by this model.\n",
    "        \"\"\"\n",
    "        users, movies = cats[:,0], cats[:,1]\n",
    "        # Embedding lookup: one row of n_factors values per index in the batch.\n",
    "        u2, m2 = self.u(users), self.m(movies)\n",
    "\n",
    "        # Concatenate (not a dot product) along dim 1 -> n_factors*2 features per example.\n",
    "        x = self.drop1(torch.cat([u2, m2], 1))\n",
    "        x = self.drop2(F.relu(self.lin1(x)))\n",
    "        # sigmoid squashes into (0, 1); rescale that into the rating range.\n",
    "        return torch.sigmoid(self.lin2(x)) * (max_rating - min_rating) + min_rating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "wd=1e-5\n",
    "model = EmbeddingNet(n_users, n_movies)\n",
    "model = model.cuda()\n",
    "opt = optim.Adam(model.parameters(), 1e-3, weight_decay=wd) # got parameter() for free , lr = 1e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "EmbeddingNet (\n",
       "  (u): Embedding(671, 50)\n",
       "  (m): Embedding(9066, 50)\n",
       "  (ub): Embedding(671, 1)\n",
       "  (mb): Embedding(9066, 1)\n",
       "  (lin1): Linear (100 -> 10)\n",
       "  (lin2): Linear (10 -> 1)\n",
       "  (drop1): Dropout (p = 0.05)\n",
       "  (drop2): Dropout (p = 0.5)\n",
       ")"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f45cf8456bc4101a27367f9c26d9f6a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in Jupyter Notebook or JupyterLab, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another notebook frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       0.74293  0.79247]                                    \n",
      "[ 1.       0.74748  0.79483]                                    \n",
      "[ 2.       0.75364  0.79638]                                    \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(model, data, 3, opt, F.mse_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                        \r"
     ]
    }
   ],
   "source": [
    "# from tqdm import tqdm as tqdm_cls\n",
    "\n",
    "# inst = tqdm_cls._instances\n",
    "# for i in range(len(inst)): inst.pop().close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_lrs(opt, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ea1cd58e12574699969fd0040c18f1c2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/html": [
       "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n",
       "<p>\n",
       "  If you're reading this message in Jupyter Notebook or JupyterLab, it may mean\n",
       "  that the widgets JavaScript is still loading. If this message persists, it\n",
       "  likely means that the widgets JavaScript library is either not installed or\n",
       "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n",
       "  Widgets Documentation</a> for setup instructions.\n",
       "</p>\n",
       "<p>\n",
       "  If you're reading this message in another notebook frontend (for example, a static\n",
       "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n",
       "  it may mean that your frontend doesn't currently support widgets.\n",
       "</p>\n"
      ],
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.       0.79631  0.78994]                                    \n",
      "[ 1.       0.78677  0.79127]                                    \n",
      "[ 2.      0.7614  0.7906]                                       \n",
      "\n"
     ]
    }
   ],
   "source": [
    "fit(model, data, 3, opt, F.mse_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Surprise package"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "collaborating filter.ipynb  ml-latest-small.zip  movielens.ipynb    tmp\r\n",
      "ml-latest-small\t\t    models\t\t ratings_small.csv\r\n"
     ]
    }
   ],
   "source": [
    "! ls ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "196\t242\t3\t881250949\r\n",
      "186\t302\t3\t891717742\r\n",
      "22\t377\t1\t878887116\r\n",
      "244\t51\t2\t880606923\r\n",
      "166\t346\t1\t886397596\r\n"
     ]
    }
   ],
   "source": [
    "! head -5 ~/.surprise_data/ml-100k/ml-100k/u.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "userId,movieId,rating,timestamp\r",
      "\r\n",
      "1,31,2.5,1260759144\r",
      "\r\n",
      "1,1029,3.0,1260759179\r",
      "\r\n",
      "1,1061,3.0,1260759182\r",
      "\r\n",
      "1,1129,2.0,1260759185\r",
      "\r\n"
     ]
    }
   ],
   "source": [
    "! head -5 'ratings_small.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from surprise import Reader, Dataset\n",
    "\n",
    "# Define the format: ml-100k u.data is tab-separated (user, item, rating, timestamp).\n",
    "reader = Reader(line_format='user item rating timestamp', sep='\\t')\n",
    "\n",
    "# Resolve the Surprise data directory from the current user's home instead of\n",
    "# hard-coding /home/ubuntu, so the cell runs on any machine.\n",
    "ml100k_path = os.path.expanduser('~/.surprise_data/ml-100k/ml-100k/u.data')\n",
    "# Load the data from the file using the reader format.\n",
    "# NOTE: this rebinds `data`, which earlier held the fastai ColumnarModelData.\n",
    "data = Dataset.load_from_file(ml100k_path, reader=reader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>31</td>\n",
       "      <td>2.5</td>\n",
       "      <td>1260759144</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1029</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1260759179</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   userId  movieId  rating   timestamp\n",
       "0       1       31     2.5  1260759144\n",
       "1       1     1029     3.0  1260759179"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ratings.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Surprise expects a frame with (user id, item id, rating) columns, in that order.\n",
    "column_sources = {'itemID': 'movieId', 'userID': 'userId', 'rating': 'rating'}\n",
    "ratings_dict = {name: list(ratings[source]) for name, source in column_sources.items()}\n",
    "df = pd.DataFrame(ratings_dict)\n",
    "\n",
    "# A Reader is still needed, but only the rating_scale parameter is required:\n",
    "# ratings_small.csv has fractional ratings (e.g. 2.5), hence the 0.5-5.0 range.\n",
    "reader = Reader(rating_scale=(0.5, 5.0))\n",
    "data = Dataset.load_from_df(df[['userID', 'itemID', 'rating']], reader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split data into 5 folds.\n",
    "# Fold assignment is presumably shuffled at random (and, presumably, so is model\n",
    "# initialisation), so seed both Python's and NumPy's generators to make the RMSE\n",
    "# numbers below reproducible.\n",
    "# NOTE(review): confirm which RNG the installed Surprise version actually draws from.\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "random.seed(0)\n",
    "np.random.seed(0)\n",
    "data.split(n_folds=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from surprise import SVD, evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from surprise import GridSearch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'lr_all': 0.002, 'reg_all': 0.4}, {'lr_all': 0.002, 'reg_all': 0.6}, {'lr_all': 0.005, 'reg_all': 0.4}, {'lr_all': 0.005, 'reg_all': 0.6}]\n",
      "------------\n",
      "Parameters combination 1 of 4\n",
      "params:  {'lr_all': 0.002, 'reg_all': 0.4}\n",
      "------------\n",
      "Mean RMSE: 0.9133\n",
      "------------\n",
      "------------\n",
      "Parameters combination 2 of 4\n",
      "params:  {'lr_all': 0.002, 'reg_all': 0.6}\n",
      "------------\n",
      "Mean RMSE: 0.9214\n",
      "------------\n",
      "------------\n",
      "Parameters combination 3 of 4\n",
      "params:  {'lr_all': 0.005, 'reg_all': 0.4}\n",
      "------------\n",
      "Mean RMSE: 0.9031\n",
      "------------\n",
      "------------\n",
      "Parameters combination 4 of 4\n",
      "params:  {'lr_all': 0.005, 'reg_all': 0.6}\n",
      "------------\n",
      "Mean RMSE: 0.9121\n",
      "------------\n"
     ]
    }
   ],
   "source": [
    "# Grid-search SVD's learning rate and regularisation over the folds defined above,\n",
    "# then report the winning configuration rather than reading it off the log.\n",
    "param_grid = {'lr_all': [0.002, 0.005],\n",
    "              'reg_all': [0.4, 0.6]}\n",
    "grid_search = GridSearch(SVD, param_grid, measures=['RMSE'])\n",
    "grid_search.evaluate(data)\n",
    "print('Best RMSE:', grid_search.best_score['RMSE'])\n",
    "print('Best params:', grid_search.best_params['RMSE'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating RMSE of algorithm SVD.\n",
      "\n",
      "------------\n",
      "Fold 1\n",
      "RMSE: 0.8990\n",
      "------------\n",
      "Fold 2\n",
      "RMSE: 0.8983\n",
      "------------\n",
      "Fold 3\n",
      "RMSE: 0.8941\n",
      "------------\n",
      "Fold 4\n",
      "RMSE: 0.8962\n",
      "------------\n",
      "Fold 5\n",
      "RMSE: 0.8962\n",
      "------------\n",
      "------------\n",
      "Mean RMSE: 0.8967\n",
      "------------\n",
      "------------\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CaseInsensitiveDefaultDict(list,\n",
       "                           {'rmse': [0.89895181594737417,\n",
       "                             0.89831051013903251,\n",
       "                             0.89405859774725671,\n",
       "                             0.89621812893141306,\n",
       "                             0.89617318551492264]})"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: SVD with Surprise's default hyper-parameters, scored on the same folds.\n",
    "algo = SVD()\n",
    "evaluate(algo, data, measures=[\"RMSE\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As shown above, even the best SVD RMSE is still higher than the default result from fast.ai's neural-net collaborative-filtering model (which uses embeddings)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's also try the `KNN` algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from surprise import KNNBasic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating RMSE, MAE of algorithm KNNBasic.\n",
      "\n",
      "------------\n",
      "Fold 1\n",
      "Computing the msd similarity matrix...\n",
      "Done computing similarity matrix.\n",
      "RMSE: 0.9662\n",
      "MAE:  0.7645\n",
      "------------\n",
      "Fold 2\n",
      "Computing the msd similarity matrix...\n",
      "Done computing similarity matrix.\n",
      "RMSE: 0.9834\n",
      "MAE:  0.7787\n",
      "------------\n",
      "Fold 3\n",
      "Computing the msd similarity matrix...\n",
      "Done computing similarity matrix.\n",
      "RMSE: 0.9802\n",
      "MAE:  0.7744\n",
      "------------\n",
      "Fold 4\n",
      "Computing the msd similarity matrix...\n",
      "Done computing similarity matrix.\n",
      "RMSE: 0.9812\n",
      "MAE:  0.7728\n",
      "------------\n",
      "Fold 5\n",
      "Computing the msd similarity matrix...\n",
      "Done computing similarity matrix.\n",
      "RMSE: 0.9804\n",
      "MAE:  0.7735\n",
      "------------\n",
      "------------\n",
      "Mean RMSE: 0.9783\n",
      "Mean MAE : 0.7728\n",
      "------------\n",
      "------------\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CaseInsensitiveDefaultDict(list,\n",
       "                           {'mae': [0.76447687302283862,\n",
       "                             0.77871336218916276,\n",
       "                             0.77444253761129189,\n",
       "                             0.77277756247233054,\n",
       "                             0.77353073380751081],\n",
       "                            'rmse': [0.96618541819639647,\n",
       "                             0.98337516247695278,\n",
       "                             0.98018440899082937,\n",
       "                             0.98120591146396685,\n",
       "                             0.98038668816669572]})"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Neighbourhood baseline (KNNBasic, default settings), reporting both RMSE and MAE.\n",
    "algo = KNNBasic()\n",
    "evaluate(algo, data, measures=[\"RMSE\", \"MAE\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NMF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from surprise import NMF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating RMSE of algorithm NMF.\n",
      "\n",
      "------------\n",
      "Fold 1\n",
      "RMSE: 0.9476\n",
      "------------\n",
      "Fold 2\n",
      "RMSE: 0.9449\n",
      "------------\n",
      "Fold 3\n",
      "RMSE: 0.9479\n",
      "------------\n",
      "Fold 4\n",
      "RMSE: 0.9494\n",
      "------------\n",
      "Fold 5\n",
      "RMSE: 0.9450\n",
      "------------\n",
      "------------\n",
      "Mean RMSE: 0.9469\n",
      "------------\n",
      "------------\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CaseInsensitiveDefaultDict(list,\n",
       "                           {'rmse': [0.9475771765522677,\n",
       "                             0.94487435132530351,\n",
       "                             0.94786484545358385,\n",
       "                             0.94936598409066575,\n",
       "                             0.94501542053063314]})"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Non-negative matrix factorisation with default settings, scored on the same folds.\n",
    "algo = NMF()\n",
    "evaluate(algo, data, measures=[\"RMSE\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Memory-based collaborative filtering (cosine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the raw ratings so this section starts from an unmodified frame.\n",
    "ratings = pd.read_csv(filepath_or_buffer='ratings_small.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>31</td>\n",
       "      <td>2.5</td>\n",
       "      <td>1260759144</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1029</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1260759179</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   userId  movieId  rating   timestamp\n",
       "0       1       31     2.5  1260759144\n",
       "1       1     1029     3.0  1260759179"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ratings.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a copy so `ratings` keeps the raw ids while ratings2 gets re-encoded.\n",
    "ratings2 = ratings.copy(deep=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Id columns to re-encode as dense 0-based integer codes.\n",
    "col = 'movieId userId'.split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-encode the raw ids as dense 0-based codes (order of first appearance) so they can\n",
    "# index the user x item matrices directly. pd.factorize produces the same codes as\n",
    "# enumerate(unique()) in one vectorised pass, and assigning the column back avoids the\n",
    "# chained `ratings2[c].replace(..., inplace=True)`, which is not guaranteed to write\n",
    "# through to ratings2 (e.g. under pandas copy-on-write).\n",
    "for c in col:\n",
    "    ratings2[c] = pd.factorize(ratings2[c])[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2.5</td>\n",
       "      <td>1260759144</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1260759179</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   userId  movieId  rating   timestamp\n",
       "0       0        0     2.5  1260759144\n",
       "1       0        1     3.0  1260759179"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ratings2.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matrix dimensions: number of distinct (re-encoded) users and movies.\n",
    "n_users, n_items = (int(ratings2[column].nunique()) for column in ('userId', 'movieId'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n_users =  671 || n_items =  9066\n"
     ]
    }
   ],
   "source": [
    "print(f\"n_users =  {n_users} || n_items =  {n_items}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/anaconda3/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "# sklearn.cross_validation is deprecated since 0.18 and removed in 0.20 (the warning this\n",
    "# cell emitted); model_selection offers the same train_test_split. The `cv` alias is kept.\n",
    "# random_state pins the 75/25 split so the MSE numbers below are reproducible.\n",
    "from sklearn import model_selection as cv\n",
    "train_data, test_data = cv.train_test_split(ratings2, test_size=0.25, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build two dense user x item matrices (0 = not rated): one for training, one for testing.\n",
    "# userId / movieId were re-encoded above as 0-based codes, so they index rows/columns\n",
    "# directly; the old `id - 1` offset sent code 0 to index -1 (the last row/column).\n",
    "def to_user_item_matrix(frame, n_rows, n_cols):\n",
    "    \"\"\"Scatter frame.rating into an (n_rows, n_cols) zero matrix at (userId, movieId).\"\"\"\n",
    "    matrix = np.zeros((n_rows, n_cols))\n",
    "    matrix[frame.userId.values, frame.movieId.values] = frame.rating.values\n",
    "    return matrix\n",
    "\n",
    "train_data_matrix = to_user_item_matrix(train_data, n_users, n_items)\n",
    "test_data_matrix = to_user_item_matrix(test_data, n_users, n_items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(ratings, similarity, type='user'):\n",
    "    \"\"\"Memory-based collaborative-filtering predictions.\n",
    "\n",
    "    ratings    : dense 2-D array, rows = users, columns = items, 0 = not rated.\n",
    "    similarity : square weight matrix -- user x user for type='user',\n",
    "                 item x item for type='item'.\n",
    "    type       : 'user' or 'item' (name kept for backward compatibility even\n",
    "                 though it shadows the builtin).\n",
    "\n",
    "    Returns an array of predicted ratings with the same shape as `ratings`.\n",
    "    Raises ValueError for any other `type`.\n",
    "    \"\"\"\n",
    "    if type == 'user':\n",
    "        # Per-user mean over every column (unrated zeros included), kept as a column\n",
    "        # vector so it broadcasts across items.\n",
    "        mean_user_rating = ratings.mean(axis=1)[:, np.newaxis]\n",
    "        ratings_diff = ratings - mean_user_rating\n",
    "        # Weighted sum of neighbours' deviations, normalised by each user's total |weight|.\n",
    "        norm = np.abs(similarity).sum(axis=1)[:, np.newaxis]\n",
    "        return mean_user_rating + similarity.dot(ratings_diff) / norm\n",
    "    if type == 'item':\n",
    "        # pred[u, i] = sum_j r[u, j] * s[j, i] / sum_j |s[j, i]|, i.e. normalise by column sums.\n",
    "        # Identical to the previous row-sum version whenever `similarity` is symmetric.\n",
    "        norm = np.abs(similarity).sum(axis=0)[np.newaxis, :]\n",
    "        return ratings.dot(similarity) / norm\n",
    "    raise ValueError(\"type must be 'user' or 'item', got %r\" % (type,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics.pairwise import pairwise_distances\n",
    "\n",
    "# pairwise_distances(metric='cosine') returns cosine DISTANCE (1 - cosine similarity),\n",
    "# which is 0 for identical rating vectors. predict() uses these matrices as neighbour\n",
    "# weights, so convert to similarity; otherwise the least similar neighbours weigh most.\n",
    "user_similarity = 1 - pairwise_distances(train_data_matrix, metric='cosine')\n",
    "item_similarity = 1 - pairwise_distances(train_data_matrix.T, metric='cosine')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense predictions for every (user, item) pair; only test entries are scored below.\n",
    "user_prediction = predict(train_data_matrix, user_similarity, type='user')\n",
    "item_prediction = predict(train_data_matrix, item_similarity, type='item')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def mse(prediction, ground_truth):\n",
    "    \"\"\"Mean squared error over the rated (non-zero) entries of `ground_truth`.\n",
    "\n",
    "    Both arguments are dense user x item arrays of the same shape; cells where\n",
    "    ground_truth == 0 are unrated and ignored. Raises ValueError if nothing is rated.\n",
    "    \"\"\"\n",
    "    rated = ground_truth.nonzero()\n",
    "    errors = prediction[rated] - ground_truth[rated]\n",
    "    if errors.size == 0:\n",
    "        raise ValueError('ground_truth has no rated (non-zero) entries')\n",
    "    return np.mean(errors ** 2)\n",
    "\n",
    "def rmse(prediction, ground_truth):\n",
    "    \"\"\"Root mean squared error -- comparable with the RMSE Surprise's evaluate() reports.\"\"\"\n",
    "    return sqrt(mse(prediction, ground_truth))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "User-based CF MSE:  11.3668710905\n",
      "Item-based CF MSE:  12.8400786831\n"
     ]
    }
   ],
   "source": [
    "# Score both predictors on the held-out test ratings.\n",
    "for cf_kind, cf_prediction in (('User', user_prediction), ('Item', item_prediction)):\n",
    "    print(cf_kind + '-based CF MSE: ', str(mse(cf_prediction, test_data_matrix)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot: comparison of algorithms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAAELCAYAAAD3HtBMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAFY5JREFUeJzt3XuYXVV9xvH3dSZFLoqtDJaKOt6l\noEIyBhFBBEQwlIpKgYIKoqOMIojaassjJNrSei+2QVNuKhhUKpYSiyAkeAECMxBIAtQL5RKrZvLg\nhUQfLpNf/9jrJIfhnDmZldnn+v08Tx722Xufs36bOTPvWWvvs7YjQgAATNeTWl0AAKAzESAAgCwE\nCAAgCwECAMhCgAAAshAgAIAsBAgAIAsBAgDIQoAAALL0t7qAajvttFMMDg62ugwA6BhjY2PrImKg\nFW23VYAMDg5qdHS01WUAQMewfV+r2mYICwCQhQABAGQhQAAAWQgQAEAWAgQAkIUAAYAWGVkyov4F\n/fJ8q39Bv0aWjLS6pGlpq8t4AaBXjCwZ0bmj5256PBETmx4vnLewVWVNCz0QAGiBRWOLprW+HREg\nANACEzExrfXtiAABgBboc9+01rcjAgQAWmB4zvC01rcjTqIDQAtUTpQvGlukiZhQn/s0PGe4Y06g\nS5IjotU1bDI0NBRMpggAW872WEQMtaJthrAAAFkIEABAFgIEAJCFAAEAZCFAAABZCBAAQBYCBACQ\nhQABAGQpNUBsP832Zbbvtn2X7X3KbA8A0DxlT2XyL5Kuioi32P4jSduV3B4AoElKCxDbO0raX9IJ\nkhQRj0h6pKz2AADNVeYQ1nMljUu60PZtts+zvX2J7QEAmqjMAOmXNFvSuRGxl6QNkj4yeSfbw7ZH\nbY+Oj4+XWA4AYCaVGSBrJK2JiOXp8WUqAuVxImJRRAxFxNDAwECJ5QAAZlJpARIRv5T0gO0Xp1UH\nSbqzrPYAAM1V9lVYp0i6JF2BdY+kE0tuDwDQJKUGSESskNSSG50AAMrFN9EBAFkIEABAFgIEAJCF\nAAEAZCFAAABZCBAAQBYCBACQhQABAGQhQAAAWQgQAEAWAgQAkIUAAQBkIUAAAFkIEABAFgIEAJCF\nAAEAZCFAAABZCBAAQBYCBACQhQABAGQhQAAAWQgQAEAWAgQAkIUAAQBkIUAAAFkIEABAlv4yX9z2\nvZIekjQh6bGIGCqzPQBA85QaIMlrI2JdE9oBADQRQ1gAgCxlB0hIutr2mO3hWjvYHrY9ant0fHy8\n5HIAADOl7AB5dUTMlnSYpPfa3n/yDhGxKCKGImJoYGCg5HIAADOl1ACJiJ+n/66VdLmkuWW2BwBo\nntICxPb2tp9SWZZ0iKRVZbUHAGiuMq/Ceoaky21X2vlaRFxVYnsAgCYqLUAi4h5JLy/r9QEArcVl\nvACALAQIACALAQIAyEKAAACyECAAgCwECAAgCwECAMhCgAAAshAgAIAsBAgAIAsBAgDIQoAAALIQ\nIACALAQIACALAQIAyEKAAACyECAAgCwECAAgCwECAMhCgAAAshAgAIAsBAgAIAsBAgDIQoAAALIQ\nIACALAQIACBL6QFiu8/2bbavLLstAEDzNKMHcqqku5rQDgCgiUoNENu7Spon6bwy2wEANF/ZPZDP\nS/obSRvr7WB72Pao7dHx8fGSywEAzJTSAsT24ZLWRsTYVPtFxKKIGIqIoYGBgbLKAQDMsDJ7IPtK\nOsL2vZIulXSg7YtLbA8A0ESlBUhEfDQido2IQUnHSLouIo4vqz0AQHPxPRAAQJYpA8T28VXL+07a\n9r4tbSQilkXE4dMvDwDQrhr1QE6vWv7CpG3vmOFaAAAdpFGAuM5yrccAgB7SKECiznKtxwCAHtLf\nYPtLbN+horfx/LSs9Ph5pVYGAGhrjQ
Jkt6ZUAQDoOFMGSETcV/3Y9tMl7S/p/kbfMAcAdLdGl/Fe\naXuPtLyLpFUqrr76qu3TmlAfAKBNNTqJ/tyIWJWWT5R0TUT8haS9xWW8ANDTGgXIo1XLB0n6jiRF\nxEOaYoZdAED3a3QS/QHbp0haI2m2pKskyfa2kmaVXBsAoI016oGcJGl3SSdIOjoifpPWv1LShSXW\nBQBoc42uwlor6T011i+VtLSsogAA7W/KALF9xVTbI+KImS0HANApGp0D2UfSA5IWS1ou5r8CACSN\nAuRPJb1O0rGS/lrSEkmLI2J12YUBANrblCfRI2IiIq6KiLerOHH+U0nLpnMvEABAd2rUA5HtbSTN\nU9ELGZR0jqTLyy0LANDuGp1E/4qkPVR8gXB+1bfSAQA9rlEP5HhJGySdKun99qZz6JYUEfHUEmsD\nALSxRt8DafRFQwBAjyIgAABZCBAAQBYCBACQhQABAGQhQAAAWUoLENtPtn2z7dttr7Y9v6y2AADN\n1/Cb6FvhYUkHRsR627Mk/dD2f0fETSW2CQBoktICJCJC0vr0cFb6F2W1BwBorlLPgdjus71C0lpJ\n10TE8jLbAwA0T6kBkmbz3VPSrpLm2t5j8j62h22P2h4dHx8vsxwAwAxqylVY6V7qSyUdWmPboogY\nioihgYGBZpQDAJgBZV6FNWD7aWl5WxU3prq7rPYAAM1V5lVYu0j6su0+FUH1jYi4ssT2AABNVOZV\nWHdI2qus1wcAtBbfRAcAZCFAAABZCBAAQBYCBACQhQABAGQhQAAAWQgQAEAWAgQAkIUAAQBkIUAA\nAFkIEABAFgIEAJCFAAEAZCFAAABZCBAAQBYCBACQhQABAGQhQAAAWQgQAEAWAgQAkIUAAQBkIUAA\nAFkIEABAFgIEAJCFAAEAZCFAAABZSgsQ28+yvdT2nbZX2z61rLYAAM3XX+JrPybpgxFxq+2nSBqz\nfU1E3FlimwCAJimtBxIRv4iIW9PyQ5LukvTMstoDADRXU86B2B6UtJek5c1oDwBQvtIDxPYOkv5D\n0mkR8bsa24dtj9oeHR8fL7scAMAMKTVAbM9SER6XRMS3au0TEYsiYigihgYGBsosBwAwg8q8CsuS\nzpd0V0R8tqx2AACtUWYPZF9Jb5V0oO0V6d8bSmwPANBEpV3GGxE/lOSyXh8A0Fp8Ex0AkIUAAQBk\nIUAAAFkIEABAFgIEAJCFAAEAZCFAAABZCBAAQBYCBACQhQABAGQhQAAAWQgQAEAWAgQAkIUAAQBk\nIUAAAFkIEABAFgIEAJCFAAEAZCFAAABZCBAAQBYCBACQhQABAGQhQAAAWQgQAEAWAgQAkIUAAQBk\nKS1AbF9ge63tVWW1AQBonTJ7IBdJOrTE1wcAtFBpARIR35f0YFmv38tGloyof0G/PN/qX9CvkSUj\nrS4JQA/qb3UBmJ6RJSM6d/TcTY8nYmLT44XzFraqLAA9qOUn0W0P2x61PTo+Pt7qctreorFF01rf\nLeh1Ae2n5QESEYsiYigihgYGBlpdTtubiIlpre8GlV5X5RgrvS5CBGitlgcIpqfPfdNa3w16tdcF\ntLsyL+NdLOlGSS+2vcb2SWW11UuG5wxPa3036MVeF9AJyrwK69iI2CUiZkXErhFxfhnt9NrY+MJ5\nC3Xy0Mmbehx97tPJQyd39Qn0Xux19dr7WurNY+50HX0VVq9ekbRw3sKuPr7JhucMP+7nXL2+G/Xi\n+7oXj7kbOCJaXcMmQ0NDMTo6usX79y/orzmM0ec+Pfaxx2ayNLTYyJIRLRpbpImYUJ/7NDxnuGv/\nsPTi+7oXj3mm2B6LiKFWtN3RPRDGxntHL/W6evF93YvH3A06+iqsXhwbR/frxfd1Lx5zN+joAOnF\nK5LQ/Xrxfd2Lx9wNOnoIqzKk0Stj4+gNvfi+7sVj7gYdfRIdAHpdK0+id/QQFgCgdQgQAEAWAgQA\nkI
UAAQBkIUAAAFna6ios2+OS7st8+k6S1s1gOZ2AY+5+vXa8Esc8Xc+JiJbcTKmtAmRr2B5t1aVs\nrcIxd79eO16JY+4kDGEBALIQIACALN0UIL14f1OOufv12vFKHHPH6JpzIACA5uqmHggAoIk6LkBs\nT9heYXu17dttf9D2k9K2A2xf2eoat5bt9VXLb7D9Y9vPsX2W7d/b3rnOvmH7M1WPP2T7rKYVvhWm\nqj0dd9h+QdX209K6ofT4Xtsr03tjhe1XNf0gpqnqvbzK9jdtb5fWh+2Lq/brtz1eeW/bPiE9rhzr\nV1p1DNNl++/T7+4dqfYzbZ89aZ89bd+Vlis/15W277T9CdtPbk31mKzjAkTSHyJiz4jYXdLrJB0m\n6cwW11QK2wdJOkfSYRFR+X7MOkkfrPOUhyW9yfZOzahvhjWqfaWkY6oeHyVp9aR9XpveG3tGxA1l\nFDnDKu/lPSQ9Iuk9af0GSXvY3jY9fp2kn0967terjvVtTap3q9jeR9LhkmZHxMskHSxpqaSjJ+16\njKTFVY9fGxEvlTRX0vMkfakJ5T6B7SHb57Si7XbViQGySUSslTQs6X223ep6ZpLt/SX9u6TDI+Jn\nVZsukHS07T+p8bTHVJyM+0ATSpxpjWr/tqS/lCTbz5f0W3XXl81+IOkFVY+/I2leWj5Wj/+D2ql2\nkbQuIh6WpIhYFxHfl/Rr23tX7fdXqnG8EbFeRci+sc77v1QRMRoR7292u81gO+veUB0dIJIUEfdI\n6pO0c6N9O8g2Kv5gvjEi7p60bb2KEDm1znP/TdJxtncssb6yTFX77yQ9YHsPFZ9Qv15jn6VpWGR5\nmUXOtPTLe5iKXlbFpZKOScM1L5M0+ZiOrhrCOrFJpW6tqyU9Kw3JLrT9mrR+sVLv0vYrJT0YET+p\n9QIR8TtJ/yvphTkF2H5bGj673fZXbQ/avi6tu9b2s9N+R6Whxdttfz+t2zREnoZVL7C9zPY9tt9f\n1cbxtm9OP5sv2fXvy2t7ve1PpWG979meW/WaR6R9+tI+t6Q6311Vz/W2/zPt/0+2j0ttr0wftDTF\nMV5k+4vp9+WTtn9ieyBte5Ltn1Ye19PxAdKlHpV0g6ST6mw/R9LbbT9l8ob0C/YVSR33SWkLar9U\nxR+aN0q6vMb2yhDW3jW2taNtba+QNCrpfknnVzZExB2SBlX0Pr5T47nVQ1gXNqPYrZV6EHNUjBqM\nS/q67RNUfBh4i4tzmZOHr2rJGm2wvbukMyQdGBEvV/Eh7AuSvpyG1C5R8bslSR+T9Pq03xF1XvIl\nkl6vYmjtTNuzbO+mYkhu34jYU9KEpOOmKGt7SdelIfmHJH1CxZDlkZIWpH1OkvTbiHiFpFdIepft\n56ZtL1fRK9tN0lslvSgi5ko6T9IpaZ96xyhJu0p6VUScLuniqloPlnR7RIxPUXvnB4jt56n4Ia1t\ndS0zaKOKbvxc2383eWNE/EbS1yS9t87zP6/iTbd9aRWWZ6rar1TxS3J/CptO94eqEDglIh6ZtP0K\nSZ9WdwxfSZIiYiIilkXEmZLeJ+nNEfGAil7FayS9WbV7l5Kk9KFpUNKPM5o/UNI3I2JdquVBSfuo\n+F2SpK9KenVa/pGki2y/S8UIRy1LIuLh9HprJT1D0kEqQvKW9OHgIBXnbep5RNJVaXmlpOsj4tG0\nPJjWHyLpben1lkt6ujb3wG6JiF+kYcGfqejladLz6x2j0v+PibR8gaTK+bR3SGr4waSj74meuldf\nlPSvERHddBokIn5ve56kH9j+VUScP2mXz0q6RTV+hhHxoO1vqPhDfEH51c6cqWpP/0/+Vnl/PDrR\nBZJ+ExErbR/Q6mK2lu0XS9pYNTy1pzZPnrpY0uck3RMRa+o8fwdJCyV9OyJ+XWatEfGedF5mnqQx\n23Nq7PZw1fKEit9Fq/i0/9EtbOrR2PxlvI2V14yIjVXnJSzplIj4
bvUT03uiuoaNVY83asv+vm+o\nLETEA7Z/ZftAFb2qqXpOkjqzB7JtGltcLel7KhJ3ftX2g2yvqfq3T2vK3HrpE9Khks6ojIdWbVun\nYhhnmzpP/4yKGT47Ud3aI+LSiLi1yfW0RESsiYhuuupnB0lfdnE57h2S/lzSWWnbNyXtrtq9raW2\nV0m6WcVQ37sz279O0lG2ny5J6UT8Ddp8dd9xKi5mkO3nR8TyiPiYiuG2Z21hG9eqGI7budKG7edk\n1lvxXUkn256VXvNFtqczulDzGOs4T8VQVnXPpK6O64FERN0TUhGxTNK29bZ3iojYoWr5AUmV8c4r\nJu13uqTT6zzvV5K2K7fSmTNV7RFxVp3nHFC1PFhedeWoPuZG69N7e1lavkjSReVVVo6IGJNU8/s5\n6QPRrBrrB2ew/dW2/0HS9bYnJN2m4jzBhbY/rCIoKhckfMr2C1V8+r9W0u0qhtgatXGn7TMkXZ3O\n6TyqYqg59zYVUvFHfVDSrS6GWcZVnAfcUvWOsZYrVAxdbdF5NaYyAQBIKr7rIulzEbHfluzfcT0Q\nAMDMs/0RSSdrC859bHoOPRAAKFf6rsXk85VvjYiVtfbvFAQIACBLJ16FBQBoAwQIACALAYK25Kmn\n7d80K6rtbdIcQitsH217v/ScFd48m20Z9R3gaU4Z7y653QBQwVVYaFd/SHMJKX0p62uSnirpzIgY\nVTF/lCTtJUlV+35R0tkRcfETX/KJ0nX1joiN06zvABUTW3bCtPFAKeiBoO1Nnra/8kk+BcvFkl6R\nehzvVjGH2MdtXyJJtj9cNYvp/LRu0Pb/uLgR0yoVM8QeYvtG27e6uLnTDmnfe23PT+tX2n6J7UEV\nE9h9ILX7uGvmXcyoeqPt22zfkKbw0KR9Bmxfk3pL59m+z+leKLZPdzET7Crbp6V129teknpjq2xP\nvocG0HT0QNARIuIeF9Ni71y1bq3td0r6UEQcLm26adGVEXGZ7UNUTDo3V8U3iq9wcZ+V+9P6t0fE\nTekP9xmSDo6IDWm+rdO1eTbUdREx2/ZIauudqaezPiI+XaPcuyXtFxGP2T5Y0j+qmCSw2pkqZmE9\n2/ahSjMvpzmXTpS0d6p5ue3rVUzI938RMS/t14nT9aPLECDoZoekf7elxzuoCI77Jd0XETel9a9U\nMS/Tj4oRLf2RpBurXudb6b9jkt60Be3uqGLOpxdKCtWYokPFjKhHSlJEXGX711XrL4+IDZJk+1uS\n9lMxY+tnbP+zioCcaj4joCkIEHQEP37a/t229Gkqzoc87haoaQhqw6T9romIY+u8TmWG08qMq418\nXNLSiDgytbVsC+utKyJ+bHu2pDdI+oTtayNiQaPnAWXiHAjanidN2z+Np35X0juqzmc8szJL6iQ3\nSdrX9gvSftvbflGD135I0hNu6JXsqM33MD+hzj4/UnG+Rmmo7Y/T+h+ouGXrdmnG1SNVTOn/Z5J+\nny4O+JSk2Q3qA0pHgKBdNZq2v6GIuFrF1Vs32l4p6TLV+KOf7rp2gqTFLqYZv1HF3eam8l+Sjqx1\nEl3SJyWdbfs21e+xzJd0iItpyo+S9EtJD6Wp6i9SMXX5cknnRcRtkl4q6WYXNxU6U8Wd64CWYioT\noAVsbyNpIp1o30fSuZVLkYFOwTkQoDWeLekb6cuRj0h6V4vrAaaNHggAIAvnQAAAWQgQAEAWAgQA\nkIUAAQBkIUAAAFkIEABAlv8H/z6zOk63tPwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f60392124a8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# NOTE(review): these scores are hard-coded and do not match the outputs above (e.g. SVD\n",
    "# mean RMSE 0.8967, KNN 0.9783, NMF 0.9469); the Surprise cells report RMSE while the\n",
    "# cosine cell prints MSE. Re-derive them from the cells above before drawing conclusions.\n",
    "mses = [6.47, .957, .897, .804, .801, .79]\n",
    "algos = ['cosine_memory', 'KNN', 'NMF', 'SVD', 'PMF', 'DL']\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(algos, mses, 'go')\n",
    "ax.set_xlabel('Different algos')\n",
    "ax.set_ylabel('Reported error (lower is better)')\n",
    "ax.set_title('Rating-prediction error by algorithm')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<a href='collaborating_filter.ipynb' target='_blank'>collaborating_filter.ipynb</a><br>"
      ],
      "text/plain": [
       "/home/ubuntu/collaborate_filter/collaborating_filter.ipynb"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "FileLink(path='collaborating_filter.ipynb')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
